import os
import pickle

import h5py
import numpy as np
import torch
from pytorch3d.io import save_ply

# --- ABO validation split: export one sample point cloud as PLY ---
dataset = 'abo'
# Context managers close the files as soon as they are read. np.array(...)
# copies the HDF5 dataset into memory before the file is closed.
with h5py.File('abo_pc_v2_val.h5', 'r') as h5_file:
    pcs = np.array(h5_file['data'])
# `texts` is not used in this section (presumably kept for interactive inspection).
with open('abo_text_v2_val.pkl', 'rb') as text_file:
    texts = pickle.load(text_file)
idx = 114
# save_ply opens the target path for writing, so the directory must exist.
os.makedirs('./examples', exist_ok=True)
# NOTE(review): save_ply expects vertices of shape (num_points, 3); assumes each
# sample in 'data' has that layout — confirm.
save_ply('./examples/{}_{}.ply'.format(dataset, idx), torch.Tensor(pcs[idx]))

# --- Text2Shape validation split: export one sample point cloud as PLY ---
dataset = 'text2shape'
# Context managers close the files as soon as they are read. np.array(...)
# copies the HDF5 dataset into memory before the file is closed.
with h5py.File('text2shape_pc_v1_val.h5', 'r') as h5_file:
    pcs = np.array(h5_file['data'])
# `texts` is not used in this section (presumably kept for interactive inspection).
with open('text2shape_text_v3_val.pkl', 'rb') as text_file:
    texts = pickle.load(text_file)
idx = 102
# Previously this sample was also written under './examples/'; switch the
# directory below if that destination is wanted instead.
out_dir = './final_captioning'
# save_ply opens the target path for writing, so the directory must exist.
os.makedirs(out_dir, exist_ok=True)
save_ply('{}/{}_{}.ply'.format(out_dir, dataset, idx), torch.Tensor(pcs[idx]))

# --- ShapeGlot validation split: export one sample point cloud as PLY ---
dataset = 'shapeglot'
# Context managers close the files as soon as they are read. np.array(...)
# copies the HDF5 dataset into memory before the file is closed.
with h5py.File('shapeglot_pc_v2_val.h5', 'r') as h5_file:
    pcs = np.array(h5_file['data'])
# Unlike the other splits, `texts` here is also read by the "swivel" search
# that follows this section in the file.
with open('shapeglot_text_v2_val.pkl', 'rb') as text_file:
    texts = pickle.load(text_file)
idx = 102
# Previously this sample was also written under './examples/'; switch the
# directory below if that destination is wanted instead.
out_dir = './final_captioning'
# save_ply opens the target path for writing, so the directory must exist.
os.makedirs(out_dir, exist_ok=True)
save_ply('{}/{}_{}.ply'.format(out_dir, dataset, idx), torch.Tensor(pcs[idx]))

# Export every shape whose captions mention "swivel". This reuses `dataset`,
# `pcs` and `texts` as last assigned by the shapeglot section of this script.
# NOTE(review): assumes texts[i] is an iterable of caption strings for shape i,
# so `'swivel' in caption` is a substring test — confirm against the pickle.
swivel_idx = [
    i
    for i, captions in enumerate(texts)
    if any('swivel' in caption for caption in captions)
]
# save_ply opens the target path for writing, so the directory must exist.
os.makedirs('./examples', exist_ok=True)
for idx in swivel_idx:
    save_ply('./examples/{}_{}.ply'.format(dataset, idx), torch.Tensor(pcs[idx]))